# SageMath grading function

#Dec 24, 2019. Nasser: Ported original Maple grading function by 
#              Albert Rich to use with Sagemath. This is used to 
#              grade Fricas, Giac and Maxima results. 
#Dec 24, 2019. Nasser: Added 'exp_integral_e', 'sgn', 'sin_integral',
#              'arctan2', 'floor', 'abs' and 'log_integral'.
#June 4, 2022  Made the default grade_annotation "none" instead of "" due to
#              an issue that occurred later when reading the file.
#July 14, 2022. Added ellipticF as a workaround until SageMath is fixed; remove it then.
 
from sage.all import * 
from sage.symbolic.operators import add_vararg, mul_vararg 
 
debug = False  # set to True to print tracing output from the grading helpers
 
def tree_size(expr):
    r"""
    Return the tree size of this expression.

    Each node counts once. A symbolic leaf (an expression with no operator)
    has size 1. A compound symbolic expression, or a non-symbolic container
    such as a list, tuple or vector, has size 1 plus the sizes of its
    children.
    """
    if expr in SR:
        node = SR(expr)
        if node.operator() is None:
            return 1
        children = node.operands()
    else:
        # lists, tuples, vectors: recurse into their items directly
        children = expr
    return 1 + sum(map(tree_size, children))
 
def is_sqrt(expr):
    """
    Return True if ``expr`` is a power whose exponent compares equal to 1/2.

    Prints a trace line when the module-level ``debug`` flag is set.
    """
    if not (expr.operator() == operator.pow):
        return False
    exponent = expr.operands()[1]
    # NOTE: unless this file is run through the Sage preparser, 1/2 here is
    # the Python float 0.5, so this relies on Sage's exact/float comparison.
    if not (exponent == 1/2):
        return False
    if debug:
        print ("expr is sqrt")
    return True
 
def is_elementary_function(func):
    """
    Return True if ``func`` (an operator object with a ``name()`` method)
    is one of the elementary functions recognised by the grader.

    Prints the verdict when the module-level ``debug`` flag is set.
    """
    elementary_names = (
        'exp', 'log', 'ln',
        'sin', 'cos', 'tan', 'cot', 'sec', 'csc',
        'arcsin', 'arccos', 'arctan', 'arccot', 'arcsec', 'arccsc',
        'sinh', 'cosh', 'tanh', 'coth', 'sech', 'csch',
        'arcsinh', 'arccosh', 'arctanh', 'arccoth', 'arcsech', 'arccsch',
        'sgn', 'arctan2', 'floor', 'abs',
    )
    found = func.name() in elementary_names
    if debug:
        verdict = " is elementary_function" if found else " is NOT elementary_function"
        print ("func ", func, verdict)
    return found
 
def is_special_function(func):
    """
    Return True if ``func`` (an operator object with a ``name()`` method)
    is one of the special functions recognised by the grader.

    Prints diagnostic information when the module-level ``debug`` flag is
    set.
    """
    if debug:
        print ("type(func)=", type(func))

    # Every entry must be its own string literal: a previously missing comma
    # fused 'sinh_integral' and 'Chi' into 'sinh_integralChi', and 'psi,zeta'
    # was one string, so none of those four names could ever match.
    special_names = (
        'erf', 'erfc', 'erfi', 'fresnel_sin', 'fresnel_cos',
        'Ei', 'Li', 'Si', 'sin_integral', 'Ci', 'cos_integral',
        'Shi', 'sinh_integral', 'Chi', 'cosh_integral',
        'gamma', 'log_gamma', 'psi', 'zeta',
        'polylog', 'lambert_w',
        'elliptic_f', 'elliptic_e', 'ellipticF', 'elliptic_pi',
        'exp_integral_e', 'log_integral',
        # NOTE(review): believed to be Sage's operator names for Li(x) and
        # dilog(x); a wrong name here simply never matches. Confirm.
        'log_integral_offset', 'dilog',
    )
    m = func.name() in special_names

    if debug:
        print ("m=",m)
        if m:
            print ("func ", func ," is special_function")
        else:
            print ("func ", func ," is NOT special_function")

    return m
 
 
def is_hypergeometric_function(func):
    """Return True if ``func.name()`` is one of Sage's hypergeometric function names."""
    name = func.name()
    return name == 'hypergeometric' or name in ('hypergeometric_M', 'hypergeometric_U')
 
def is_appell_function(func):
    """
    Return True if ``func.name()`` is 'hypergeometric'.

    Placeholder for an Appell-function test: the original note says appellf1
    could not be found in SageMath, so only the generic name is matched.
    """
    return func.name() == 'hypergeometric'
 
def is_atom(expn):
    """
    Return True if ``expn`` is atomic, in the sense of Maple's ``atomic`` type.

    An object is atomic when it is one of the following:
    * a symbolic expression with no operator;
    * an element of ZZ, QQ, AA or QQbar;
    * an element of its parent's base ring, or one of the parent's generators,
      for parents exposing ``base_ring`` and ``gens``.

    Objects without a ``parent`` attribute are never atoms. An AttributeError
    raised during the checks is reported and treated as "not an atom".
    """
    if debug:
        print ("Enter is_atom, expn=",expn)

    if not hasattr(expn, 'parent'):
        return False

    # thanks to answer at https://ask.sagemath.org/question/49179/what-is-sagemath-equivalent-to-atomic-type-in-maple/
    try:
        owner = expn.parent()
        if owner is SR:
            return expn.operator() is None
        if owner in (ZZ, QQ, AA, QQbar):
            return expn in owner  # should always be True
        if not (hasattr(owner, "base_ring") and hasattr(owner, "gens")):
            return False
        return expn in owner.base_ring() or expn in owner.gens()
    except AttributeError as error:
        print("Exception,AttributeError in  is_atom")
        print ("cought exception" , type(error).__name__ )
        return False
 
 
def expnType(expn):
    """
    Return the expression-type order of ``expn``; higher means a more
    complicated kind of function.

    * 1 -- atoms, and square roots / rational powers of rational numbers
    * 2 -- rational (non-integer) powers of non-rational bases
    * 3 -- elementary functions, and powers with a non-rational exponent
    * 4 -- special functions (``is_special_function``)
    * 5 -- hypergeometric functions
    * 6 -- Appell functions (``is_appell_function``)
    * 8 -- otherwise-unrecognised expressions whose text contains "Integral"
    * 9 -- anything else

    Integer powers take the order of their base. Sums and products take the
    maximum order of their operands. Lists and tuples take the maximum order
    of their items.
    """
    if debug:
        print (">>>>>Enter expnType, expn=", expn)
        print (">>>>>is_atom(expn)=", is_atom(expn))

    if is_atom(expn):
        return 1
    if isinstance(expn, (list, tuple)):
        return max(map(expnType, expn))

    # Operands of a symbolic expression are themselves symbolic Expression
    # objects, so checks like ``type(x) == Integer`` / ``type(x) == Rational``
    # are never true for them. Membership in ZZ / QQ tests the numeric value
    # instead, and also works for native Sage Integers/Rationals.
    if is_sqrt(expn):
        base = expn.operands()[0]
        if base in QQ:
            return 1
        return max(2, expnType(base))

    op = expn.operator()
    if op == operator.pow:
        operands = expn.operands()
        base, exponent = operands[0], operands[1]
        if exponent in ZZ:
            return expnType(base)
        if exponent in QQ:
            if base in QQ:
                return 1
            return max(2, expnType(base))
        return max(3, expnType(base), expnType(exponent))
    if op == add_vararg or op == mul_vararg:
        operands = expn.operands()
        return max(expnType(operands[0]), expnType(operands[1:]))
    if is_elementary_function(op):
        # only the first argument is inspected
        return max(3, expnType(expn.operands()[0]))
    # ``max(k, *...)`` also copes with a function that has no operands.
    if is_special_function(op):
        return max(4, *map(expnType, expn.operands()))
    if is_hypergeometric_function(op):
        return max(5, *map(expnType, expn.operands()))
    if is_appell_function(op):
        return max(6, *map(expnType, expn.operands()))
    # Unevaluated integrals are normally rejected before grading is called;
    # this branch is kept for completeness.
    if "Integral" in str(expn):
        return max(8, *map(expnType, expn.operands()))
    return 9
 
#main function 
def grade_antiderivative(result,optimal):
    """
    Grade the antiderivative ``result`` against ``optimal``.

    Returns a ``(grade, grade_annotation)`` pair of strings. The possible
    outcomes are:

    * "C" when ``expnType(result)`` exceeds ``expnType(optimal)``.
    * "C" when ``result`` contains ``I`` but ``optimal`` does not.
    * Otherwise "A" (annotation "none") when the tree size of ``result`` is
      at most twice that of ``optimal``.
    * Otherwise "B", with an annotation giving both sizes.

    The grade and annotation are always printed before returning.
    """
    if debug:
        print ("Enter grade_antiderivative for sagemath")
        print("Enter grade_antiderivative, result=",result)
        print("Enter grade_antiderivative, optimal=",optimal)
        print("type(anti)=",type(result))
        print("type(optimal)=",type(optimal))

    size_result = tree_size(result)
    size_optimal = tree_size(optimal)
    size_limit = 2 * size_optimal

    order_result = expnType(result)
    order_optimal = expnType(optimal)

    if debug:
        print ("expnType_result=", order_result, "expnType_optimal=", order_optimal)

    if order_result > order_optimal:
        grade = "C"
        grade_annotation = ("Result contains higher order function than in optimal. Order "
                            + str(order_result) + " vs. order " + str(order_optimal) + ".")
    else:
        # optimal is only inspected for I when result contains it
        result_is_complex = result.has(I)
        if result_is_complex and not optimal.has(I):
            grade = "C"
            grade_annotation = "Result contains complex when optimal does not."
        elif size_result <= size_limit:
            grade = "A"
            grade_annotation = "none"
        else:
            grade = "B"
            if result_is_complex:
                reason = "Both result and optimal contain complex but leaf count of result is larger than twice the leaf count of optimal. "
            else:
                reason = "Leaf count of result is larger than twice the leaf count of optimal. "
            grade_annotation = (reason + str(size_result) + " vs. $2 (" + str(size_optimal)
                                + ") = " + str(size_limit) + "$.")

    print("Before returning. grade=",grade, " grade_annotation=",grade_annotation)

    return grade, grade_annotation